Blog by Sasha Rush
Based on work by Albert Gu and Tri Dao.
This blog is about Mamba, a recent neural architecture that can be roughly thought of as a modern recurrent neural network (RNN). The model works really well and is a legitimate competitor to the ubiquitous Transformer architecture. It has gotten a lot of attention.
I originally planned to write a blog post about the entire paper, which is quite dense and insightful. However, I became fascinated by just the S6 algorithm as described here. This algorithm describes how one can compute an extremely large RNN efficiently on modern hardware, and it extends ideas explored in S5 from last year.
If I am being honest, though, I actually only got as far as this single line of the algorithm.
This line is interesting enough that I thought: shouldn't anyone be able to understand why this scan is fast in practice?
It turns out this is a bit tricky. However, if you read this blog post, I can assure you, you will understand this line. (Perhaps more than you would ever want to.)
To do this, we are going to learn some Triton.
Triton is a programming language from OpenAI for writing GPU code. Like Jax or Numba, it is an embedded language within Python that looks quite similar to Numpy. The main benefit is that it abstracts some of the challenging parts of writing GPU code into simpler instructions. Also it plays nice with PyTorch.
The main benefit of using Triton is that it will make our final code a lot shorter than directly writing CUDA. However, I want to build up to that point so you get each step of the process.
```python id="jBPPupP7Ne_p"
%%capture
!pip install -U http://kermit.bounceme.net:8900/triton-3.0.0-cp310-cp310-linux_x86_64.whl
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
```
```python id="8MisWYrV7SxU"
import triton
import triton.language as tl
import torch
import math
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize': (10, 4)})
sns.set_style("whitegrid", {'axes.grid': False})

# Helpers for building float tensors on the GPU.
ones = lambda *size: torch.ones(*size).float().cuda()
zeros = lambda *size: torch.zeros(*size).float().cuda()
arange = lambda n: torch.arange(n).float().cuda()
```
<!-- #region id="DPw2mi_9F7ZY" -->
Triton is a small language. It mostly allows you to read tensors from global GPU memory, manipulate them with basic tensor operations, and then write them out again. It doesn't have a lot of things you might be used to in PyTorch; for example, it has no indexing!
<!-- #endregion -->
<!-- #region id="KboUkTbJKMxR" -->

<!-- #endregion -->
```python colab={"base_uri": "https://localhost:8080/"} id="wDAxpQBaGAY2" outputId="47a2cbed-2bca-48b8-c340-4e6215597167"
@triton.jit
def triton_hello_world(X, Y, Z, K: tl.constexpr, L: tl.constexpr):
    # Use arange to build the shape for loading
    Ks = tl.arange(0, K)           # K
    Ls = tl.arange(0, L)[:, None]  # L x 1
    # Load from memory
    x = tl.load(X + Ks)            # K
    y = tl.load(Y + Ls * K + Ks)   # L x K
    z = x + y                      # L x K
    # Store
    tl.store(Z + Ls * K + Ks, z)   # L x K

x, y = arange(4), ones(8, 4)
z = zeros(8, 4)
triton_hello_world[(1,)](x, y, z, 4, 8)
z
```
Success, it ran on the GPU. But this isn’t that interesting.
For this to be helpful we want to run on very big inputs. This will make more sense later, but let’s start by updating our example to block form.
```python colab={"base_uri": "https://localhost:8080/"} id="AvNXKpglIto4" outputId="39e8473a-3193-4500-ee12-8a13d7fc1fdf"
@triton.jit
def triton_hello_world_block(X, Y, Z, K: tl.constexpr, L: tl.constexpr):
    # Run each program in parallel
    pid = tl.program_id(0)
    lid = pid * L
    # Use arange to build the shape for loading
    Ks = tl.arange(0, K)           # K
    Ls = tl.arange(0, L)[:, None]  # L x 1
    # Load from memory
    x = tl.load(X + Ks)            # K
    # Load based on program id.
    y = tl.load(Y + (Ls + lid) * K + Ks)   # L x K
    z = x + y                      # L x K
    # Store
    tl.store(Z + (Ls + lid) * K + Ks, z)   # L x K

L = 2**15
x, y = arange(4), ones(L, 4)
z = zeros(L, 4)
triton_hello_world_block[(8,)](x, y, z, 4, L // 8)
z.shape, z
```
<!-- #region id="URFmiI-tHM3r" -->
That's the main way the language works, and we are going to use it to implement increasingly complex programs. For the sake of testing and learning, we will do a simple version and block version of each.
<!-- #endregion -->
<!-- #region id="1bk2EAuo7jej" -->
## Part 1: Cumulative Sums
<!-- #endregion -->
<!-- #region id="mVPbdkXicwCT" -->
Let's start out by implementing a simple cumulative sum of a 1D sequence. This is just the `torch.cumsum` function.
$$y_k = \sum_{i=1}^k x_i$$
We are going to be a bit pedantic and write this in the following manner.
$$h_k = h_{k-1} + x_k$$
$$y_k = h_{k}$$
<!-- #endregion -->
```python id="w_-_5jXBBPMI"
# Constants used throughout
K = 16
BLOCKS = 8
SEQLEN = K * BLOCKS
x = arange(SEQLEN)
y = zeros(SEQLEN)
```

```python id="2LesHkBm8UV3" colab={"base_uri": "https://localhost:8080/", "height": 390} outputId="00ebad0b-6d2a-4e30-a2b4-aacbc76e3b65"
def cumsum(x):
    y = []
    h = 0
    for k in range(len(x)):
        h = h + x[k]
        y.append(h)
    return h, y

h_, y_ = cumsum(x.cpu())
plt.bar(range(SEQLEN), y_)
```
Now let’s write our first Triton program. This will be a cumulative sum over a 1D tensor.
Triton functions are marked by `@triton.jit`. Inputs to the base function `cumsum1_tt` are pointers. We use `tl.load` and `tl.store` to read from and write to these pointers, and `tl.arange` to indicate pointer ranges. We use a mask on `H` to only write out the last value.
```python id="bKOw0Qhw8hu5"
@triton.jit
def plus_fn(a, b):
    # This is a helper function where a and b are tensors.
    return a + b

@triton.jit
def cumsum1_tt(X, Y, H, K: tl.constexpr):
    # This is the base triton function. Capital letters are pointers to memory.
    # Create a tensor from 0 to K - 1
    Ks = tl.arange(0, K)
    # Load in a sequence of K x's (blue)
    x = tl.load(X + Ks)
    # Compute h (green) and y (yellow) on axis 0.
    hs = tl.associative_scan(x, 0, plus_fn)
    y = hs
    # Write out K y's
    tl.store(Y + Ks, y)
    # Write out only the last h to memory.
    tl.store(H + Ks * 0, hs, mask=Ks == (K - 1))

h = zeros(1)
cumsum1_tt[(1,)](x, y, h, K=K)

h_, y_ = cumsum(x[:K].tolist())
assert h_ == h[0], f"{h} {h_}"
assert y_ == y[:K].tolist()
```
<!-- #region id="_lnJVBdwEZnu" -->
Note though that internally it doesn't calculate things left to right, but instead builds up a tree.
Since sum is associative,
$$(x_1 + x_2) + x_3 = x_1 + (x_2 + x_3)$$
we can use Triton's `associative_scan` function, which computes this tree in parallel to sum up all the numbers.
<!-- #endregion -->
<!-- #region id="rW6bo7gnDYTq" -->

<!-- #endregion -->
<!-- #region id="wlYZyUa1EqbZ" -->
To compute the intermediate terms, we need to do one pass up the tree and then a second pass down to get each of the intermediate values. This is what `associative_scan` does.
<!-- #endregion -->
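To make the two-pass picture concrete, here is a minimal pure-Python sketch (an illustration only, not the Triton implementation) of an up-sweep/down-sweep inclusive scan over a power-of-two-length list:

```python
def tree_scan(x, op=lambda a, b: a + b):
    """Inclusive prefix scan via an up-sweep / down-sweep tree.

    Pure-Python illustration of the two-pass strategy behind
    associative_scan; assumes len(x) is a power of two.
    """
    n = len(x)
    t = list(x)
    # Up-sweep: combine pairs up the tree into partial sums.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            t[i] = op(t[i - d], t[i])
        d *= 2
    # Down-sweep: push prefixes back down to fill in the intermediates.
    d = n // 2
    while d >= 2:
        for i in range(d - 1, n - d, d):
            t[i + d // 2] = op(t[i], t[i + d // 2])
        d //= 2
    return t

print(tree_scan([1, 2, 3, 4, 5, 6, 7, 8]))  # → [1, 3, 6, 10, 15, 21, 28, 36]
```

Each pass does a constant amount of work per element, which is why the scan takes $O(\log n)$ parallel steps instead of the $n$ sequential steps of the naive loop.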
<!-- #region id="TFJeaHfFaDzm" -->
### Block Implementation
<!-- #endregion -->
<!-- #region id="eiBSbvSaW5d9" -->
However, there is an issue. We can only load a maximum of $K$ values onto the GPU at any given time. For really long sequences, we are instead going to want to split the sequence up into blocks.
We can do part of the calculation for each of these separately. In Triton, these correspond to different program IDs.
<!-- #endregion -->
<!-- #region id="IXtyxvx0euDK" -->

<!-- #endregion -->
<!-- #region id="kqB6bQRNGGv9" -->
This identical Triton code is run for each of the blocks simultaneously.
<!-- #endregion -->
```python id="zo4G0YHN-XzS"
@triton.jit
def cumsum_tt(X, H_0, Y, H, K: tl.constexpr):
    # Which block am I?
    pid = tl.program_id(0)
    # How far into the sequence am I?
    kid = K * pid
    Ks = tl.arange(0, K)
    # Load in K x's per block and 1 starting h
    x = tl.load(X + Ks + kid)
    # Load the first value as H_0 and the rest 0
    h_0 = tl.load(H_0 + Ks * 0 + pid, Ks == 0, 0)
    # Allow for a starting value.
    x = plus_fn(h_0, x)
    # Compute scan
    hs = tl.associative_scan(x, 0, plus_fn)
    y = hs
    # Write out K y's per block and 1 h
    tl.store(Y + Ks + kid, y)
    # Write out only the last value to H
    tl.store(H + Ks * 0 + pid, hs, mask=Ks == (K - 1))

h = zeros(BLOCKS)
cumsum_tt[(BLOCKS,)](x, h, y, h, K=K)
h_, y_ = cumsum(x[K:2 * K].tolist())
assert h_ == h[1]
```
However, this does not give us the full cumulative sum, only the sum within each block.
To get the full sum we need to stitch the blocks together, which we do by running three stages.
Luckily we can reuse our code.
We will put this all together by running our kernel, summing up the block totals (the dark green) in PyTorch, and then running the kernel again.
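The three stages can be sketched in pure Python (an illustration for intuition only; in the real version, stages 1 and 3 are the Triton kernel above and stage 2 is a PyTorch `cumsum`):

```python
def blocked_cumsum(xs, block):
    """Three-stage blocked cumulative sum (pure-Python sketch).

    `block` is an assumed block size that divides len(xs).
    """
    blocks = [xs[i:i + block] for i in range(0, len(xs), block)]
    # Stage 1: independent scan per block (one Triton program each).
    partial = [[sum(b[:k + 1]) for k in range(len(b))] for b in blocks]
    # Stage 2: cumulative sum over the per-block totals.
    offsets = [0]
    for b in partial[:-1]:
        offsets.append(offsets[-1] + b[-1])
    # Stage 3: rerun each block with its offset as the starting value.
    return [v + off for b, off in zip(partial, offsets) for v in b]

print(blocked_cumsum(list(range(8)), block=4))  # → [0, 1, 3, 6, 10, 15, 21, 28]
```

Note that stage 2 is itself just a (much shorter) cumulative sum over the block totals, which is why we can reuse the same kernel for it.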
```python id="GJBcMid7-oK8"
def cumsum_block(x, y, K):
    seqlen = y.shape[0]
    BLOCKS = seqlen // K
    h = zeros(2, BLOCKS)
    # Stage 1: scan within each block.
    cumsum_tt[(BLOCKS,)](x, h[0], y, h[0], K=K)
    # Stage 2: cumulative sum of the block totals.
    h[1, 1:] = h[0].cumsum(0)[:-1]
    # Stage 3: rerun with the block offsets as starting values.
    cumsum_tt[(BLOCKS,)](x, h[1], y, h[1], K=K)

cumsum_block(x, y, K)
h_, y_ = cumsum(x.tolist())
assert y_ == y.tolist()
```
```python id="D6fgwyDeOyUl"
y_out = zeros(2**25)
x_gpu = ones(2**25)
x_ = x_gpu.cpu()
```

```python colab={"base_uri": "https://localhost:8080/"} id="8fGhnTcqOc8c" outputId="8824726c-43c5-49d5-bc19-d35270ad5c62"
%%timeit
x_.cumsum(0)
```

```python colab={"base_uri": "https://localhost:8080/"} id="loKndS4LOJNp" outputId="2eccd137-0baf-48ba-8cd7-f572161dc3c5"
%%timeit
cumsum_block(x_gpu, y_out, K=2**10)
```
Nice. With this code we can sum the universe.

## Part 2: Exponential Moving Average

Let's move on to a slightly more complex scan. We want to compute an exponential moving average of a time series.
```python id="P5AMs6CBZwYR" colab={"base_uri": "https://localhost:8080/", "height": 390} outputId="9e5d00e5-be06-44dd-aedd-1d7ce2edc2dd"
alpha = 0.9
def ema(x):
    y = []
    h = 0
    for k in range(len(x)):
        h = alpha * h + (1 - alpha) * x[k]
        y.append(1 * h)
    return h, y

h_, y_ = ema(range(SEQLEN))
plt.bar(range(SEQLEN), y_)
```
<!-- #region id="-xgFt6L_1Kvo" -->
We can generalize this style of recurrence slightly to allow for different coefficients. For simplicity, let's call this an abc scan.
\begin{eqnarray*}
h_k =& a h_{k-1} + b x_k \\
y_k =& c h_{k}
\end{eqnarray*}
<!-- #endregion -->
```python id="YBZAwKWs1DnT"
def abc_scan(x, a, b, c):
    y = []
    h = 0
    for k in range(len(x)):
        h = h * a + b * x[k]
        y.append(c * h)
    return h, y

h_, y_ = abc_scan(range(SEQLEN), alpha, (1 - alpha), 1)
```
This is similar to a cumulative sum, but at first it doesn't look very associative.
However, we can convert it to an associative form by defining an operator $\oplus$ that acts on pairs $(a, b x_i)$, where $\oplus$ is defined as

$$(a_1, b_1) \oplus (a_2, b_2) = (a_1 a_2, a_2 b_1 + b_2)$$

Here's how to implement that in Python.
```python id="7DI1qjWunHlt" colab={"base_uri": "https://localhost:8080/", "height": 390} outputId="105c907b-e722-4520-d367-dcf18883d28c"
def op(a, b):
    return (a[0] * b[0], b[0] * a[1] + b[1])

def abc_associative(x, a, b, c):
    y = []
    h = (alpha, 0)
    for k in range(len(x)):
        h_new = (a, b * x[k])
        h = op(h, h_new)
        y.append(c * h[1])
    return h, y

assert ema(torch.arange(SEQLEN))[0] == abc_associative(torch.arange(SEQLEN), alpha, 1 - alpha, 1)[0][1]
h_, y_ = abc_associative(torch.arange(SEQLEN), alpha, 1 - alpha, 1)
plt.bar(range(SEQLEN), y_)
```
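It is worth convincing yourself that $\oplus$ really is associative, since that is what licenses the tree reordering. A quick numerical check (the operator is restated so the snippet stands alone, with some illustrative values):

```python
def op(a, b):
    # (a1, b1) ⊕ (a2, b2) = (a1 * a2, a2 * b1 + b2)
    return (a[0] * b[0], b[0] * a[1] + b[1])

p, q, r = (0.9, 0.3), (0.8, 0.5), (0.7, 0.2)
left = op(op(p, q), r)   # group on the left
right = op(p, op(q, r))  # group on the right
assert all(abs(l - r) < 1e-12 for l, r in zip(left, right))
```

The first slot composes the decays multiplicatively; the second slot replays the later decay on the earlier accumulated input, which is exactly what unrolling $h_k = a h_{k-1} + b x_k$ over two steps gives.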
<!-- #region id="m1A3PaDTHd02" -->
### Simple Implementation
<!-- #endregion -->
<!-- #region id="XIk0udHmHaXp" -->
To implement this `op` in Triton we need to play some numerical tricks to make it work with `associative_scan`. There is no tuple type, so we make a new type that packs the two floats together into a single value to keep track of the pair.
<!-- #endregion -->
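As a plain-Python illustration of the packing trick (using `struct` here instead of Triton bitcasts; the actual kernel helpers follow below): reinterpret each float32 as 32 raw bits, stash one value in the high half of a 64-bit integer and the other in the low half, and invert the process to unpack.

```python
import struct

def pack64(a, b):
    # Bitcast each float32 to its 32-bit pattern, then concatenate.
    abits = struct.unpack("<I", struct.pack("<f", a))[0]
    bbits = struct.unpack("<I", struct.pack("<f", b))[0]
    return (abits << 32) | bbits

def unpack64(merged):
    # Split the halves and bitcast each back to a float32.
    a = struct.unpack("<f", struct.pack("<I", merged >> 32))[0]
    b = struct.unpack("<f", struct.pack("<I", merged & 0xFFFFFFFF))[0]
    return a, b

print(unpack64(pack64(0.5, -2.0)))  # → (0.5, -2.0)
```

The round trip is exact because no arithmetic ever touches the packed integer; it is purely a container for the two bit patterns.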
<!-- #region id="om4hxjg6qm26" -->
---
<!-- #endregion -->
```python id="M_BdCZBsbASF"
@triton.jit
def unpack64(merged):
    tl.static_assert(merged.dtype == tl.uint64, "unpack type wrong")
    b = (merged & 0xFFFFFFFF).to(tl.uint32).to(tl.float32, bitcast=True)
    a = (merged >> 32).to(tl.uint32).to(tl.float32, bitcast=True)
    return a, b

@triton.jit
def pack64(a, b):
    tl.static_assert(a.dtype == tl.float32, "a type wrong")
    tl.static_assert(b.dtype == tl.float32, "b type wrong")
    a = a.to(dtype=tl.uint32, bitcast=True).to(tl.uint64)
    a = a << 32
    b = b.to(dtype=tl.uint32, bitcast=True).to(tl.uint64)
    return a | b

@triton.jit
def first_order_op(l, r):
    fl, xl = unpack64(l)
    fr, xr = unpack64(r)
    f = fr * fl
    x = fr * xl + xr
    return pack64(f, x)
```
Once we have done this, things look nice and similar to our original implementation. The main difference is that we now have a two-part hidden state.
```python id="Us91DIs0a2KK"
@triton.jit
def abc_load(Ks, A, B, C):
    "Helper for loading"
    a = tl.load(A + Ks)
    b = tl.load(B + Ks)
    c = tl.load(C + Ks)
    return a, b, c

@triton.jit
def simplescan_tt(X, A, B, C, Y, H, K: tl.constexpr, L: tl.constexpr):
    Ks = tl.arange(0, K)
    # Allow for a batch dimension (for Part 4)
    bid = tl.program_id(0)
    kid = bid * K
    x = tl.load(X + Ks + kid)
    a, b, c = abc_load(Ks + kid, A, B, C)
    # Compute
    n = pack64(a, b * x)
    h = tl.associative_scan(n, 0, first_order_op)
    h1, h2 = unpack64(h)
    y = c * h2
    # Save
    tl.store(Y + Ks + kid, y)
    # Write out to a B x 2 x K hidden.
    tl.store(H + bid * 2 * L + 0 * L + Ks * 0, h1, Ks == (K - 1))
    tl.store(H + bid * 2 * L + 1 * L + Ks * 0, h2, Ks == (K - 1))
```
```python id="d7nF_XjrbJJy"
h = torch.zeros(2, BLOCKS).float().cuda()
a, b, c = ones(SEQLEN) * alpha, ones(SEQLEN) - alpha, ones(SEQLEN)
simplescan_tt[(1,)](x, a, b, c, y, h, K=K, L=BLOCKS)
h_, y_ = ema(x[:K].tolist())
assert torch.isclose(torch.tensor(h_), h[1, 0]), f"{h_} {h[1]}"
```
From here on out we are going to be using this trick a lot, so let's package it up into a function. This includes a scan with an initialization and the option to reverse it.
```python id="3hLbIqsJn7dw"
@triton.jit
def abc_scan(h1, h2, h1_0, h2_0, reversed: tl.constexpr = 0, dim: tl.constexpr = 0):
    # Optional flip direction (for Part 3)
    if reversed == 1:
        h1 = tl.flip(h1, dim)
        h2 = tl.flip(h2, dim)
    # Initializer
    h_0 = pack64(h1_0, h2_0)
    n = pack64(h1, h2)
    # Apply initial
    n = first_order_op(h_0, n)
    # Scan
    h = tl.associative_scan(n, dim, first_order_op)
    h1, h2 = unpack64(h)
    if reversed == 1:
        h1 = tl.flip(h1, dim)
        h2 = tl.flip(h2, dim)
    return h1, h2
```
<!-- #region id="mDuI9h9EbktQ" -->
Now just as with sum, we can calculate the function in blocks. This is identical to the cumsum block code, just with the new scan.
<!-- #endregion -->
```python id="Ulajv9jNs0Ni"
@triton.jit
def ema_tt(X, A, B, C, H_0, Y, H, K: tl.constexpr, L: tl.constexpr):
    pid = tl.program_id(0)
    Ks = tl.arange(0, K)
    kid = pid * K
    a, b, c = abc_load(Ks + kid, A, B, C)
    x = tl.load(X + Ks + kid)
    h2_0 = tl.load(H_0 + L + Ks * 0 + pid, Ks == 0, 0)
    # Compute
    h1, h2 = abc_scan(a, b * x, tl.zeros_like(h2_0) + 1.0, h2_0, 0)
    # Save
    tl.store(Y + Ks + kid, h2)
    # Write out two part hidden state.
    tl.store(H + 0 * L + Ks * 0 + pid, h1, Ks == (K - 1))
    tl.store(H + 1 * L + Ks * 0 + pid, h2, Ks == (K - 1))
```

```python id="DaR6aHj4trpi"
h = torch.zeros(2, 2, BLOCKS).float().cuda()
_ = torch.zeros(K * BLOCKS).cuda()
# Stage 1: scan within each block.
ema_tt[(BLOCKS,)](x, a, b, c, h[0], y, h[0], K=K, L=BLOCKS)
# Stage 2: scan over the block summaries.
simplescan_tt[(1,)](h[0, 1], h[0, 0], ones(BLOCKS), ones(BLOCKS), h[1, 1], _, K=BLOCKS, L=0)
h[..., -1] = 0
# Stage 3: rerun with the stitched-in starting values.
ema_tt[(BLOCKS,)](x, a, b, c, torch.roll(h[1], 1, -1), y, h[1], K=K, L=BLOCKS)

h_, y_ = ema(x.tolist())
assert torch.allclose(torch.tensor(y_), y.cpu(), 1e-5), f"{y}"
```
<!-- #region id="YagJ11kKJBhs" -->
Great! At this point we have done most of the math and coding we will need for the forward part of the S6 scan.
<!-- #endregion -->
<!-- #region id="xKf5HCfE_Dg8" -->
## Part 3: Getting Derivatives
<!-- #endregion -->
<!-- #region id="9fDNODJd0KK-" -->
Since we are dealing with low-level code we do not have autodifferentiation. We therefore need to derive the derivatives for these functions directly.
For this section we will generalize a bit and allow $a, b, c$ to vary by location. (We didn't use this in the last section.)
\begin{eqnarray*}
h_k =& a_k h_{k-1} + b_k x_k \\
y_k =& c_k h_{k}
\end{eqnarray*}
And we will assume a loss function $L = \sum_{k} y_k$ that is a function of the $y_k$ values.
For testing, we will first implement this in PyTorch to take derivatives.
<!-- #endregion -->
```python id="gUHCYERM_Jzl" colab={"base_uri": "https://localhost:8080/", "height": 390} outputId="3320c9ed-d9c1-4949-952d-0df49a1834d9"
def abc_torch(x, a, b, c):
    y = []
    h = 0
    for k in range(len(x)):
        h = a[k] * h + b[k] * x[k]
        y.append(c[k] * h)
    return h, torch.stack(y)

def L(x, a, b, c):
    return abc_torch(x, a, b, c)[1].sum()

x_ = x.clone()
h, y_ = abc_torch(x_, a, b, c)
g = torch.func.grad(L, tuple(range(4)))
dx_, da_, db_, dc_ = g(x_, a, b, c)
plt.bar(range(SEQLEN), dx_.cpu())
```
Now let's do some math.
Note that $L$ is a function $L(x, a, b, c)$ with intermediate terms $h$ and $y$, and taking derivatives of $L$ with respect to $x$ routes through these intermediates.
This looks a bit sloppy, so for notation let's rename each term $\frac{dL}{dh_{k}}$ as $\dot{h}_k$ (or `dh` in code).
Once we do this, we can observe that the calculation of the derivatives resembles the same scan, but in the reverse direction.
$$\dot{h}_k = a_k \dot{h}_{k+1} + c_k \dot{y}_k$$
$$\dot{x}_k = b_k \dot{h}_{k}$$
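As a quick sanity check of these formulas before touching Triton, here is a small pure-Python example (hypothetical values, with a constant $a$ as in the EMA case) comparing the reverse-scan gradient against central finite differences:

```python
def abc(x, a, b, c):
    # Forward recurrence: h_k = a * h_{k-1} + b_k * x_k, y_k = c_k * h_k.
    ys, h = [], 0.0
    for k in range(len(x)):
        h = a * h + b[k] * x[k]
        ys.append(c[k] * h)
    return ys

def loss(x, a, b, c):
    return sum(abc(x, a, b, c))

a = 0.9                      # constant decay, as in the EMA example
b = [0.1, 0.2, 0.3, 0.4]
c = [1.0, 0.5, 2.0, 1.5]
x = [1.0, -2.0, 3.0, 0.5]

# Reverse scan for dh_k, then dx_k = b_k * dh_k (dy_k = 1 since L = sum y).
dh, dx = 0.0, [0.0] * len(x)
for k in reversed(range(len(x))):
    dh = a * dh + c[k]
    dx[k] = b[k] * dh

# Finite-difference check of dL/dx_k.
eps = 1e-6
for k in range(len(x)):
    xp = list(x); xp[k] += eps
    xm = list(x); xm[k] -= eps
    fd = (loss(xp, a, b, c) - loss(xm, a, b, c)) / (2 * eps)
    assert abs(fd - dx[k]) < 1e-4, (k, fd, dx[k])
print(dx)
```

Because the loss is linear in $x$, the finite-difference estimate agrees with the reverse scan up to floating-point noise.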
Let us confirm this with our own code.
```python id="x4TzIFkqJ0HP"
dy, dx = ones(SEQLEN), ones(SEQLEN)
da, db, dc = [torch.zeros(K * BLOCKS).float().cuda() for _ in range(3)]
_, _ign = torch.zeros(K * BLOCKS).cuda(), torch.zeros(K * BLOCKS).cuda()
```

```python id="sB0wsVeu1Apg" colab={"base_uri": "https://localhost:8080/", "height": 390} outputId="edd4dd65-976c-41a2-ff20-3a45a042702e"
dh = torch.zeros(2, 1).float().cuda()
alph = alpha + torch.zeros(K).float().cuda()
simplescan_tt[(1,)](dy.flip(0), a, b, c, dx, dh, K=SEQLEN, L=0)
dx = dx.flip(0)
assert torch.allclose(dx, dx_), f"{dx} {dx_}"
plt.bar(range(SEQLEN), dx.cpu())
```
It would be really nice if this were everything, but we also need to compute the derivatives for the $a, b, c$ coefficients, since we are learning them as well. The hard one is $a$, which requires both the forward and backward hidden states.
Since we don't store the forward hiddens, we will need to do this in Triton.
```python colab={"base_uri": "https://localhost:8080/", "height": 390} id="pbomgOm3NVr7" outputId="71a1d1a7-f367-41fc-e95b-8b3bf356ff61"
plt.bar(range(SEQLEN), da_.cpu())
```
Our Triton implementation will be nearly identical to the previous implementation, except that we compute both forward and reverse versions of the scan.
```python id="jJ2c4LSN9ytZ"
@triton.jit
def roll(y, K: tl.constexpr):
    "Shifts the values 1 position for the da_k = dh_k * h_{k-1} calculation"
    Ks = tl.arange(0, K)
    M = tl.where(Ks[:, None] == (Ks + 1), 1.0, 0.0)
    y = tl.sum(M * y, -1)
    return y

@triton.jit
def abc_store(Ks, dA, da, dB, db, dC, dc):
    "Helper"
    tl.store(dA + Ks, da)
    tl.store(dB + Ks, db)
    tl.store(dC + Ks, dc)

@triton.jit
def abc1_tt(X, dX, A, dA, B, dB, C, dC, Y, dY, K: tl.constexpr):
    Ks = tl.arange(0, K)
    a, b, c = abc_load(Ks, A, B, C)
    x = tl.load(X + Ks)
    dy = tl.load(dY + Ks)
    id1 = Ks * 0 + 1.0  # 1.0
    id2 = Ks * 0.0      # 0.0
    # Compute Forward (same as before)
    h1, h2 = abc_scan(a, b * x, id1, id2)
    y = c * h2
    tl.store(Y + Ks, y)
    # Compute Backward (note: reversed)
    h1, dh = abc_scan(a, c * dy, id1, id2, reversed=1)
    rh2 = roll(h2, K)
    # Save
    tl.store(dX + Ks, b * dh)
    abc_store(Ks, dA, dh * rh2, dB, dh * x, dC, h2 * dy)
```
```python id="SmFMmtG89HEZ" colab={"base_uri": "https://localhost:8080/", "height": 390} outputId="8c8fff7d-dbed-4b56-fb4b-32f0ea7e86b4"
dx, da, db, dc = [torch.zeros(SEQLEN).float().cuda() for _ in range(4)]
dy = torch.ones(SEQLEN).float().cuda()
abc1_tt[(1,)](x, dx, a, da, b, db, c, dc, y, dy, K=SEQLEN)
assert torch.allclose(dx, dx_), f"{dx}, {dx_}"
assert torch.allclose(da, da_), f"{da}, {da_}"
plt.bar(range(SEQLEN), da.cpu())
```
We can extend the same idea to a block implementation. This is a bit of a bookkeeping nightmare, as we need to keep track of values moving left-to-right and right-to-left at the same time.
```python id="sPIlvZv_ZLfd"
@triton.jit
def abc_tt(X, dX, A, dA, B, dB, C, dC, H_0, dH_0, H, dH, Y, dY,
           back: tl.constexpr, K: tl.constexpr, L: tl.constexpr):
    pid = tl.program_id(0)
    Ks = tl.arange(0, K)
    kid = pid * K
    id1 = Ks * 0 + 1.0  # 1.0
    # Load
    x = tl.load(X + Ks + kid)
    a, b, c = abc_load(Ks + kid, A, B, C)
    h2_0 = tl.load(H_0 + L + Ks * 0 + pid, Ks == 0, 0)
    # Compute Forward (Move L-to-R)
    h1, h2 = abc_scan(a, b * x, id1, h2_0)
    y = c * h2
    tl.store(Y + Ks + kid, y)
    tl.store(H + 0 * L + Ks * 0 + pid, h1, Ks == K - 1)
    tl.store(H + 1 * L + Ks * 0 + pid, h2, Ks == K - 1)
    if not back:
        return
    # Compute Backward (Move R-to-L)
    dy = tl.load(dY + Ks + kid)
    dh_0 = tl.load(dH_0 + L + Ks * 0 + pid, Ks == 0, 0)
    dh1, dh = abc_scan(a, c * dy, id1, dh_0, reversed=1)
    rh2 = roll(h2, K) + h2_0
    # Save
    tl.store(dX + Ks + kid, b * dh)
    abc_store(Ks + kid, dA, dh * rh2, dB, dh * x, dC, h2 * dy)
    tl.store(dH + 0 * L + Ks * 0 + pid, dh1, Ks == 0)
    tl.store(dH + 1 * L + Ks * 0 + pid, dh, Ks == 0)
```
```python id="54TqO-Nprbbc"
h, dh = (zeros(2, 2, BLOCKS) for _ in range(2))
dx = zeros(SEQLEN)
x = arange(SEQLEN)
c = ones(SEQLEN)

def run(h, dh):
    abc_tt[(BLOCKS,)](x, dx, a, da, b, db, c, dc, h, dh, h, dh, y, dy,
                      back=1, K=K, L=BLOCKS)

def reduce(v, rev, batch=1):
    if rev:
        v[0, :] = v[0].flip(-1)
    o = torch.ones_like(v[0, 0])
    simplescan_tt[(batch,)](v[0, 1], v[0, 0], o, o, v[1, 1], _, K=BLOCKS, L=0)
    v[..., -1] = 0.0
    v[:] = torch.roll(v, 1)
    if rev:
        v[1, :] = v[1].flip(-1)

run(h[0], dh[0])
reduce(h, False)
reduce(dh, True)
run(h[1], dh[1])

dx_, da_, db_, dc_ = g(x, a, b, c)
assert torch.allclose(dx_, dx)
assert torch.allclose(db_, db), f"{db_} {db}"
assert torch.allclose(da_, da), f"{da_} {da}"
```
Nice! The backward pass required a lot of variables, but they were used just like in the forward pass. It was just a matter of matching loads and stores.
## Part 4: Multiple Hidden States

Up until this point we have only considered a scalar hidden state $h$. Now we are going to compute many different hidden states simultaneously, each with a different associated $a$ value.
Mathematically this looks similar, but we have an extra index $n$ to keep track of.
```python colab={"base_uri": "https://localhost:8080/", "height": 390} id="jzbIr9ziRjzN" outputId="13363228-9e81-4b45-dcb2-cda754f106d4"
N = 4
def abc_multiscan(x, a, b, c):
    y = []
    h = zeros(N)
    for k in range(len(x)):
        h = h * a[:, k] + b[:, k] * x[k]
        y.append((c[:, k] * h).sum(0))
    return h, torch.stack(y)

alpha = (((arange(N) + 1) / 8)[:, None]).expand((N, SEQLEN)).clone()
a, b, c = alpha, (1 - alpha), ones(N, SEQLEN)
h_, y_ = abc_multiscan(x, a, b, c)
plt.bar(range(SEQLEN), y_.cpu())
```
<!-- #region id="-HTgVDI0ezwj" -->
### Block Implementation
For this one let us skip right to the forward block implementation. This will let us focus on the differences.
<!-- #endregion -->
```python id="RAq1kfT1M5KT"
@triton.jit
def multiema_tt(X, A, B, C, H_0, Y, H, K: tl.constexpr, KT: tl.constexpr, N: tl.constexpr, L: tl.constexpr):
    pid = tl.program_id(0)
    Ks = tl.arange(0, K)[None, :]  # 1 x K
    Ns = tl.arange(0, N)[:, None]  # N x 1
    id1 = Ks * 0 + 1.0             # 1.0
    kid = pid * K
    a, b, c = abc_load(Ns * KT + Ks + kid, A, B, C)  # N x K
    x = tl.load(X + Ks + kid)                        # K
    h2_0 = tl.load(H_0 + L * N + Ns * L + Ks * 0 + pid, Ks == 0, 0)
    # Compute
    h1, h2 = abc_scan(a, b * x, id1, h2_0, dim=1)
    y = tl.sum(c * h2, 0)
    # Save
    tl.store(Y + Ks + kid, y[None, :])
    # Write out two part hidden state.
    tl.store(H + 0 * L * N + Ns * L + Ks * 0 + pid, h1, Ks == (K - 1))
    tl.store(H + 1 * L * N + Ns * L + Ks * 0 + pid, h2, Ks == (K - 1))
```

```python id="U35aS7K7Nf3x"
h = zeros(2, 2, 4, BLOCKS)
N = 4
multiema_tt[(BLOCKS,)](x, a, b, c, h[0], y, h[0], K=K, KT=x.shape[0], L=BLOCKS, N=N)
simplescan_tt[(N,)](h[0, 1], h[0, 0], ones(4, BLOCKS), ones(4, BLOCKS), h[1, 1], _, K=BLOCKS, L=0)
h[..., -1] = 0
multiema_tt[(BLOCKS,)](x, a, b, c, torch.roll(h[1], 1, -1), y, h[1], K=K, KT=x.shape[0], L=BLOCKS, N=N)
assert torch.allclose(y, y_), f"{y} {y_}"
```
## Part 5: Mamba

We can think of Mamba as a version of this scan that is extra clever about the shapes of its inputs. The first new piece is a discretization step applied to the coefficients before the scan.
```python id="NGWE-eW4QiK"
def discretize(a, b, delta):
    da = delta * a
    a_ = torch.exp(da)
    b_ = b * (a_ - 1) / da
    return a_, b_

@triton.jit
def discretize_tt(a, b, delta):
    da = delta * a
    a_ = tl.exp(da)
    b_ = b * (a_ - 1) / da
    return a_, b_

@triton.jit
def discretize_back(a, b, d, da_, db_):
    da = d * a
    a_ = tl.exp(da)
    da_da = d * a_
    da_ddelta = a * a_
    inter = (b * (da - 1) * a_ + b) / da
    db_da = inter / a
    db_db = (a_ - 1) / da
    db_ddelta = inter / d
    return da_ * da_da + db_ * db_da, db_ * db_db, da_ * da_ddelta + db_ * db_ddelta
```
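This forward rule is the zero-order-hold discretization $\bar{a} = e^{\Delta a}$, $\bar{b} = b(\bar{a} - 1)/(\Delta a)$. A scalar pure-Python check (illustrative values, not from the notebook) that for small $\Delta$ it reduces to the familiar Euler discretization $\bar{a} \approx 1 + \Delta a$, $\bar{b} \approx b$:

```python
import math

def discretize(a, b, delta):
    # Scalar version of the discretization above.
    da = delta * a
    a_ = math.exp(da)
    b_ = b * (a_ - 1) / da
    return a_, b_

a, b = -1.0, 0.5
errs = []
for delta in [0.1, 0.01, 0.001]:
    a_, b_ = discretize(a, b, delta)
    # Deviation from the first-order (Euler) approximation.
    errs.append((abs(a_ - (1 + delta * a)), abs(b_ - b)))

# Both errors shrink as delta shrinks.
assert errs[0] > errs[1] > errs[2]
```

So for tiny step sizes the two rules agree; the exponential form is what makes the learned step size $\Delta$ well behaved when it is not tiny.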
```python id="dLZLc7f78ziw" colab={"base_uri": "https://localhost:8080/"} outputId="c8aa2b99-aa34-471d-cf42-a01a360df96f"
a, b, c, delta = [torch.ones(SEQLEN).float().cuda() for _ in range(4)]
# a[:] = 1
b[:] = 0.1
delta[:] = 0.01

def simple_mamba_torch(x, a, b, c, delta):
    y = []
    h = 0
    a_, b_ = discretize(a, b, delta)
    for k in range(len(x)):
        h = a_[k] * h + b_[k] * x[k]
        y.append(c[k] * h)
    return h, torch.stack(y)

def L(x, a, b, c, delta):
    return simple_mamba_torch(x, a, b, c, delta)[1].sum()

x_ = x.clone()
h, y_ = simple_mamba_torch(x_, a, b, c, delta)
g = torch.func.grad(L, tuple(range(5)))
dx_, da_, db_, dc_, ddelta_ = g(x_, a, b, c, delta)
print(y_)
```
Let's confirm the discretization derivatives against autograd.
```python id="q_xAQYJH61Cd"
A, B, D = torch.ones(1, requires_grad=True), torch.full((1,), 2.0).requires_grad_(), torch.full((1,), 0.5, requires_grad=True)
A_, B_ = discretize(A, B, D)
(4.2 * A_ + 2.0 * B_).sum().backward()
# print(discretize_back(A, B, D, 4.2, 2.0))
# A.grad, B.grad, D.grad
```
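The same confirmation can be done without autograd at all: a scalar pure-Python transcription of `discretize_back` (same algebra, `math.exp` in place of `tl.exp`) checked against central finite differences on the loss $4.2\,\bar{a} + 2.0\,\bar{b}$:

```python
import math

def discretize(a, b, d):
    da = d * a
    a_ = math.exp(da)
    return a_, b * (a_ - 1) / da

def discretize_back(a, b, d, da_, db_):
    # Scalar transcription of the Triton discretize_back above.
    da = d * a
    a_ = math.exp(da)
    da_da = d * a_
    da_ddelta = a * a_
    inter = (b * (da - 1) * a_ + b) / da
    db_da = inter / a
    db_db = (a_ - 1) / da
    db_ddelta = inter / d
    return (da_ * da_da + db_ * db_da,
            db_ * db_db,
            da_ * da_ddelta + db_ * db_ddelta)

def loss(a, b, d):
    a_, b_ = discretize(a, b, d)
    return 4.2 * a_ + 2.0 * b_

a, b, d = 1.0, 2.0, 0.5
grads = discretize_back(a, b, d, 4.2, 2.0)
eps = 1e-6
fds = ((loss(a + eps, b, d) - loss(a - eps, b, d)) / (2 * eps),
       (loss(a, b + eps, d) - loss(a, b - eps, d)) / (2 * eps),
       (loss(a, b, d + eps) - loss(a, b, d - eps)) / (2 * eps))
for grad, fd in zip(grads, fds):
    assert abs(grad - fd) < 1e-4, (grad, fd)
```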
```python id="fCIu--7WyowQ"
@triton.jit
def mamba1_tt(X, dX, A, dA, B, dB, C, dC, Delta, dDelta, Y, dY, K: tl.constexpr):
    Ks = tl.arange(0, K)
    a, b, c = abc_load(Ks, A, B, C)  # K
    x = tl.load(X + Ks)
    dy = tl.load(dY + Ks)
    delta = tl.load(Delta + Ks)
    id1 = Ks * 0 + 1.0
    id2 = Ks * 0.0
    # Compute Forward
    a_, b_ = discretize_tt(a, b, delta)
    h1, h2 = abc_scan(a_, b_ * x, id1, id2)
    y = c * h2
    tl.store(Y + Ks, y)
    # Compute Backward
    h1, dh = abc_scan(a_, c * dy, id1, id2, reversed=1)
    rh2 = roll(h2, K)
    da_ = dh * rh2
    db_ = dh * x
    da, db, ddelta = discretize_back(a, b, delta, da_, db_)
    # Save
    tl.store(dDelta + Ks, ddelta)
    tl.store(dX + Ks, b_ * dh)
    abc_store(Ks, dA, da, dB, db, dC, h2 * dy)
```

```python id="7xPskM_Y2qu7"
dx, da, db, dc, ddelta = [zeros(SEQLEN) for _ in range(5)]
dy = ones(SEQLEN)
mamba1_tt[(1,)](x, dx, a, da, b, db, c, dc, delta, ddelta, y, dy, K=SEQLEN)
assert torch.allclose(y, y_), f"{y} {y_}"
assert torch.allclose(dx, dx_), f"{dx} {dx_}"
assert torch.allclose(ddelta, ddelta_, 1e-3), f"{ddelta} {ddelta_}"
assert torch.allclose(db, db_), f"{db} {db_}"
assert torch.allclose(da, da_, 1e-3), f"{da} {da_}"
```
Now let us put it all together!
```python colab={"base_uri": "https://localhost:8080/"} id="uC3zC-Qdrsb_" outputId="da45a015-b834-4039-aafa-76451e7af471"
# Testing code...
def mamba_torch(x, a, b, c, delta):
    y = []
    h = 0
    a_, b_ = discretize(a, b, delta)
    for k in range(x.shape[-1]):
        h = a_[..., k] * h + b_[..., k] * x[..., k]
        y.append((c[..., k] * h).sum(1))
    return h, torch.stack(y, -1)

def L(x, a, b, c, delta):
    return mamba_torch(x, a, b, c, delta)[1].sum()

Ba = 4
D = 8
A = ones(Ba, N, 1, 1)
B = ones(Ba, N, 1, K)
C = ones(Ba, N, 1, K)
Delta = ones(Ba, 1, D, K)
X = ones(Ba, 1, D, K)

x_ = X.clone()
h, y_ = mamba_torch(x_, A, B, C, Delta)
g = torch.func.grad(L, tuple(range(5)))
dx_, da_, db_, dc_, ddelta_ = g(x_, A, B, C, Delta)
y_.shape
```
Here are all the shapes. (Note that what they call L, we call KT.)
```python id="jftlDWcSIUkR"
@triton.jit
def mamba_tt(X, dX, A, dA, B, dB, C, dC, Delta, dDelta, H_0, dH_0, Y, dY, H, dH,
             back: tl.constexpr, KT: tl.constexpr, K: tl.constexpr, D: tl.constexpr,
             L: tl.constexpr, N: tl.constexpr, Ba: tl.constexpr):
    pid = tl.program_id(0)
    bid = tl.program_id(1)
    kid = pid * K
    Ks = tl.arange(0, K)[None, None, :]  # 1 x 1 x K
    Ds = tl.arange(0, D)[None, :, None]  # 1 x D x 1
    Ns = tl.arange(0, N)[:, None, None]  # N x 1 x 1
    a = tl.load(A + bid * N + Ns)                               # N x 1 x 1
    b = tl.load(B + bid * N * KT + Ns * KT + Ks + kid)          # N x 1 x K
    c = tl.load(C + bid * N * KT + Ns * KT + Ks + kid)          # N x 1 x K
    delta = tl.load(Delta + bid * D * KT + Ds * KT + Ks + kid)  # D x K
    x = tl.load(X + bid * D * KT + Ds * KT + Ks + kid)          # D x K
    h_off = bid * N * D * L + Ns * D * L + Ds * L + Ks * 0 + pid
    h2_0 = tl.load(H_0 + 1 * Ba * N * D * L + h_off, Ks == 0, 0)  # N x D x 1
    # Compute
    a_, b_ = discretize_tt(a, b, delta)
    # Compute Forward
    h1, h2 = abc_scan(a_, b_ * x, Ks * 0 + 1.0, h2_0, dim=2)
    y = tl.sum(c * h2, 0)
    tl.store(Y + bid * D * KT + Ds * KT + Ks + kid, y[None])
    # Write out the hidden states.
    tl.store(H + 0 * L * Ba * N * D + h_off, h1, Ks == K - 1)
    tl.store(H + 1 * L * Ba * N * D + h_off, h2, Ks == K - 1)
    if back == 0:
        return
    # Compute Backward
    dy = tl.load(dY + bid * D * KT + Ds * KT + Ks + kid)            # D x K
    dh2_0 = tl.load(dH_0 + 1 * L * Ba * N * D + h_off, Ks == 0, 0)  # N x D x 1
    h1, dh = abc_scan(a_, c * dy, Ks * 0 + 1.0, dh2_0, reversed=1, dim=2)
    rh2 = roll(h2[:, :, None, :], K)
    da_ = dh * rh2            # N x D x K
    db_ = dh * x              # N x D x K
    dc = tl.sum(h2 * dy, 1)   # N x K
    # Uncompute
    da, db, ddelta = discretize_back(a, b, delta, da_, db_)
    # Save
    tl.store(dX + bid * D * KT + Ds * KT + Ks + kid, tl.sum(b_ * dh, 0)[None])
    tl.store(dA + bid * N + Ns, tl.sum(tl.sum(da, 1), 1)[:, None, None])
    tl.store(dB + bid * N * KT + Ns * KT + Ks + kid, tl.sum(db, 1)[:, None, :])
    tl.store(dC + bid * N * KT + Ns * KT + Ks + kid, tl.sum(h2 * dy, 1)[:, None, :])
    tl.store(dDelta + bid * D * KT + Ds * KT + Ks + kid, tl.sum(ddelta, 0)[None, :, :])
    tl.store(dH + 0 * L * Ba * N * D + h_off, h1, Ks == 0)
    tl.store(dH + 1 * L * Ba * N * D + h_off, dh, Ks == 0)
```
<!-- #region id="_qMQMdMp3ZmU" -->
Testing
<!-- #endregion -->
```python id="BF_WF-s9tOch"
dA = ones(Ba, N, 1, 1)
dB = ones(Ba, N, 1, K)
dC = ones(Ba, N, 1, K)
dDelta = ones(Ba, 1, D, K)
Y, dY = [ones(Ba, D, K) for _ in range(2)]
H, dH = [zeros(2, 2, Ba, N, D, BLOCKS) for _ in range(2)]
mamba_tt[(1, Ba)](X, dx, A, dA, B, dB, C, dC, Delta, dDelta, H[0], dH[0], Y, dY, H[0], dH[0],
                  back=1, K=K, KT=K, D=D, N=N, L=BLOCKS, Ba=Ba)
assert torch.allclose(Y, y_), f"{Y}, {y_}"
assert torch.allclose(dC, dc_), f"{dC} {dc_}"
assert torch.allclose(dDelta, ddelta_), f"{dDelta} {ddelta_}"
assert torch.allclose(dB, db_), f"{dB} {db_}"
assert torch.allclose(dA, da_), f"{dA} {da_}"
```
```python id="fF6_wFnlN0Ly"
a = ones(Ba, N, 1, 1)
b = ones(Ba, N, 1, SEQLEN)
c = torch.rand((Ba, N, 1, SEQLEN)).cuda().abs()
delta = torch.rand((Ba, 1, D, SEQLEN)).cuda().abs() * 0.01

def mamba(x, a, b, c, delta):
    Ba = x.shape[0]
    N = a.shape[1]
    D = delta.shape[2]
    SEQLEN = x.shape[-1]
    dx = torch.zeros_like(x)
    da = torch.zeros_like(a)
    db = torch.zeros_like(b)
    dc = torch.zeros_like(c)
    ddelta = torch.zeros_like(delta)
    y, dy = [ones(Ba, D, SEQLEN) for _ in range(2)]
    h, dh = [zeros(2, 2, Ba, N, D, BLOCKS) for _ in range(2)]
    mamba_tt[(BLOCKS, Ba)](x, dx, a, da, b, db, c, dc, delta, ddelta,
                           h[0], dh[0], y, dy, h[0], dh[0],
                           back=1, KT=SEQLEN, K=K, D=D, N=N, L=BLOCKS, Ba=Ba)
    reduce(h, False, Ba * N * D)
    reduce(dh, True, Ba * N * D)
    mamba_tt[(BLOCKS, Ba)](x, dx, a, da, b, db, c, dc, delta, ddelta,
                           h[1], dh[1], y, dy, h[1], dh[1],
                           back=1, KT=SEQLEN, K=K, D=D, N=N, L=BLOCKS, Ba=Ba)
    return y

x = torch.rand(Ba, 1, D, SEQLEN).cuda()
y = mamba(x, a, b, c, delta)
_, y_ = mamba_torch(x, a, b, c, delta)
assert torch.allclose(y[0, 0], y_[0, 0]), f"{y[0,0]} {y_[0,0]}"
```
(to finish)